import time
import itertools
import requests

import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
import seaborn as sns

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import silhouette_score, silhouette_samples, pairwise_distances, make_scorer, accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import cross_validate
from scipy.spatial.distance import squareform, pdist

from skfda import datasets
from skfda.exploratory.visualization.clustering import (
    ClusterMembershipLinesPlot,
    ClusterMembershipPlot,
    ClusterPlot,
)
from skfda.ml.clustering import FuzzyCMeans, KMeans, AgglomerativeClustering, NearestNeighbors
from skfda.representation.grid import FDataGrid
from skfda.misc.metrics import angular_distance, fisher_rao_distance, fisher_rao_amplitude_distance, fisher_rao_phase_distance, lp_distance, LpDistance, l2_distance, PairwiseMetric
from skfda.ml.classification import KNeighborsClassifier
from skfda.exploratory.outliers import MSPlotOutlierDetector, BoxplotOutlierDetector
import skfda
import somadata


# Clustering estimators to compare, keyed by the label used in result names.
cluster_methods = {
    'FuzzyCMeans': FuzzyCMeans,
    'KMeans': KMeans,
    'AgglomerativeClustering': AgglomerativeClustering,
}

# Candidate distances between curves, keyed by the label used in result names.
pdists = {
    'angular_distance': angular_distance,
    'l2_distance': l2_distance,
}

# Six evenly spaced colours sampled from the viridis colormap.
cmap = plt.get_cmap('viridis')
colors = cmap(np.linspace(0, 1, num=6))

# Sampling grid shared by every FDataGrid built below: one value per time
# point (Pre_Tx, D7, D13, D20, D26, D33, D51).
grid_points = [-3, 7, 13, 20, 26, 33, 51]


# Clinical data
# Sheets read from the clinical workbook; each one is reduced to the seven
# sampled days and then joined column-wise.
sheet_names = ['Physical exam','Chemistry','Complete Blood Count ']
df_clin = []

for sn in sheet_names:
    sheet = pd.read_excel('/mnt/d/xeno/Xenotransplant Clinical data ( POD -3 to POD +51 ).xlsx', sheet_name=sn)
    # For these sheets the row labelled 0 is discarded before parsing
    # (presumably a units/sub-header row — TODO confirm against the workbook).
    if sn in ['Chemistry','Complete Blood Count ','Coagulation','Urine Chemistry']:
        sheet = sheet.drop(0)

    # The first column holds the day; the second is used only as a tie-breaker
    # when the same day appears more than once.
    day_col = sheet.columns[0]
    tie_col = sheet.columns[1]
    sheet.iloc[:,0] = sheet[day_col].astype(int)
    sheet = sheet[sheet[day_col].isin([-3, 7, 13, 20, 26, 33, 51])]
    sheet = sheet.sort_values([day_col, tie_col]).drop_duplicates(subset=day_col)

    # Index by day and discard any measurement missing on a kept day.
    df_clin.append(sheet.set_index(day_col).dropna(axis=1))

df_clin = pd.concat(df_clin, axis=1)
df_clin = df_clin[[name for name in df_clin.columns if 'Date' not in name]]
df_clin = (
    df_clin
    .drop('Blood Pressure (mmHg )- (Average / Day)', axis=1)
    .astype(float)
    .T
)
# After the transpose: one row per clinical variable, one column per day.
df_clin.columns = ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51']
# Standardise each variable (row) across the time points.
df_clin.iloc[:,:] = StandardScaler().fit_transform(df_clin.T).T


# Metabolomics
# Load the targeted metabolomics results and reshape them to one row per
# metabolite and one column per sample.
df_mb = pd.read_excel('/mnt/e/xeno_metabolomics/MC-TB-115_results.xlsx', sheet_name='targeted_data')#.dropna(subset='KEGG_ID')
# aggfunc is given as the string 'sum': passing the Python builtin `sum` is
# deprecated in pandas >= 2.1 (FutureWarning) and will stop mapping to the
# optimised groupby sum in a future release.
df_mb = df_mb.pivot_table(index='Metabolite', columns='Sample_ID', values='normalized_area', aggfunc='sum').drop('pig', axis=1)
# Standardise each metabolite (row) across samples.
# NOTE(review): this runs before 'Pre_Tx_1' is dropped, so that sample still
# contributes to each row's mean and variance — confirm this is intended.
df_mb.iloc[:,:] = StandardScaler().fit_transform(df_mb.iloc[:,:].T).T
# Discard metabolites that still contain missing values.
df_mb = df_mb.dropna(axis=0)
# Columns are relabelled by position; assumes the pivoted Sample_ID columns
# come out in this chronological order — TODO confirm against the raw IDs.
df_mb.columns = ['Pre_Tx_1', 'Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51']
df_mb = df_mb.drop('Pre_Tx_1', axis=1)

## Detect outliers
# Keep only the metabolite curves that the detector labels 1; everything
# else is removed before the data sets are combined.
out_detector = MSPlotOutlierDetector(cutoff_factor=.7, random_state=42)
fd_mb = FDataGrid(df_mb.values, grid_points)
mb_is_inlier = out_detector.fit_predict(fd_mb) == 1
df_mb = df_mb[mb_is_inlier].copy()


# Somascan (proteomics)
# Read the SomaScan ADAT file and reduce it to one standardised row per gene
# symbol with seven time-point columns.
adat = somadata.read_adat('/mnt/e/xeno_somascan/HMS-24-042_v5.0_EDTAPlasma.hybNorm.medNormInt.plateScale.calibrate.anmlQC.qcCheck.anmlSMP.20240827.adat')
# Column (axis=1) metadata restricted to a single field each. `col_genes` is
# not used again in the visible part of this script.
col_check = adat.pick_meta(axis=1, names=['ColCheck'])#.reset_index()
col_genes = adat.pick_meta(axis=1, names=['EntrezGeneSymbol'])
col_organism = adat.pick_meta(axis=1, names=['Organism'])
# Keep only EntrezGeneSymbol as the column label, then move the row index
# into ordinary columns so SampleDescription can be filtered on.
df_ss = adat.pick_meta(axis=1, names=['EntrezGeneSymbol']).reset_index()
# Keep only the rows whose SampleDescription contains 'Visit'.
df_ss = df_ss[df_ss.SampleDescription.str.contains('Visit')]
# presumably converts the somadata object into a plain DataFrame — TODO confirm
df_ss = pd.DataFrame(df_ss)
# Keep the columns from 'CRYBB2' onward and transpose.
# NOTE(review): assumes 'CRYBB2' is the first analyte column, so everything
# before it is sample metadata — confirm against the ADAT.
df_ss = df_ss.loc[:,'CRYBB2':].T
# Attach the QC flag and organism to each row by position.
# NOTE(review): assumes col_check.columns / col_organism.columns hold one
# value per analyte in the same order as the rows of df_ss — confirm.
df_ss['col_check'] = col_check.columns
df_ss['col_organism'] = col_organism.columns
# Keep only rows flagged 'PASS' with organism 'Human', then drop the helper columns.
df_ss = df_ss[(df_ss.col_check=='PASS') & (df_ss.col_organism=='Human')].drop(['col_check','col_organism'], axis=1)
# Collapse rows that share a gene symbol to their median.
df_ss = df_ss.reset_index().groupby('EntrezGeneSymbol').median()
df_ss = df_ss.T
# Drop the column whose gene symbol is the empty string.
df_ss = df_ss.iloc[:,np.where(df_ss.columns!='')[0]]
# Gene symbols are the columns at this point, so each gene is standardised
# across the remaining samples.
df_ss.iloc[:,:] = StandardScaler().fit_transform(df_ss.iloc[:,:])
df_ss = df_ss.T
# Columns are relabelled by position; assumes the retained samples are
# already in chronological order — TODO confirm.
df_ss.columns = ['Pre_Tx','D7', 'D13','D20','D26','D33','D51']

## Detect outliers
# Keep only the protein curves that the detector labels 1; everything else
# is removed before the data sets are combined.
out_detector = MSPlotOutlierDetector(cutoff_factor=.7, random_state=42)
fd_ss = FDataGrid(df_ss.values, grid_points)
ss_is_inlier = out_detector.fit_predict(fd_ss) == 1
df_ss = df_ss[ss_is_inlier].copy()


# Clinical + Metabolomics + Somascan
# Tag every row with the data set it came from so the origin survives the
# concatenation below.
for frame, origin in ((df_clin, 'clinical'), (df_mb, 'metabolomics'), (df_ss, 'proteomics')):
    frame['type'] = origin

# Rows are stacked in the order clinical, proteomics, metabolomics.
df_clin_mb_ss = pd.concat([df_clin, df_ss, df_mb], axis=0)

# Functional representation of the combined curves: every column except the
# 'type' tag, sampled on the shared grid.
curve_values = df_clin_mb_ss.drop(columns='type').values
fd_clin_mb_ss = FDataGrid(curve_values, grid_points)

## Calculate pairwise distances to explore the silhouette score for different numbers of clusters
# The distance matrices are computed once per metric and reused for every
# clustering run below. They are built from the combined clinical +
# proteomics + metabolomics table (the same rows that are clustered), using
# only the leading time-point columns so that the 'type' tag and any columns
# appended later are excluded.
curve_table = df_clin_mb_ss.iloc[:, :len(grid_points)]

pdists_mb_ss_dict = dict()
t1 = time.time()
for p, metric in pdists.items():
    pdists_mb_ss_dict[p] = pairwise_distances(curve_table, metric=metric)
    # silhouette_score(metric='precomputed') needs an exactly-zero diagonal;
    # overwrite whatever the metric returned for each curve against itself.
    np.fill_diagonal(pdists_mb_ss_dict[p], 0)
    print('#', end='')
t2 = time.time()
print('\npairwise_distances calculations!')
print(int((t2-t1)//60), (t2-t1)%60, sep=':')

## Silhouette score calculations
# For every (clustering method, distance) pair, fit 2..29 clusters on the
# combined functional data and score the labels against the matching
# precomputed distance matrix. Results are stored as
# mb_ss['<method>_<distance>'][n_clusters] = silhouette score.
mb_ss = dict()
t1 = time.time()
for m, p in itertools.product(cluster_methods, pdists):
    mb_ss_aux = dict()
    for nc in range(2, 30):
        if m == 'AgglomerativeClustering':
            # Built with average linkage and without a random_state argument.
            clus = cluster_methods[m](n_clusters=nc, metric=pdists[p], linkage=AgglomerativeClustering.LinkageCriterion.AVERAGE)
        else:
            clus = cluster_methods[m](n_clusters=nc, metric=pdists[p], random_state=42)
        clus.fit(fd_clin_mb_ss)
        mb_ss_aux[nc] = silhouette_score(pdists_mb_ss_dict[p], labels=clus.labels_, metric='precomputed')
    mb_ss[m+'_'+p] = mb_ss_aux

t2 = time.time()
print('silhouette_score calculations!')
print(int((t2-t1)//60), (t2-t1)%60, sep=':')


## Plotting silhouette score
# One line per '<method>_<distance>' combination, with the number of clusters
# on the x axis and the silhouette score on the y axis.
silhouette_table = pd.DataFrame(mb_ss)
silhouette_long = silhouette_table.melt(ignore_index=False).reset_index()

fig = plt.figure(figsize=(10,4))
ax = sns.lineplot(data=silhouette_long, x='index', y='value', hue='variable', lw=1)

cluster_counts = list(range(2, 30))
ax.set_xticks(cluster_counts, cluster_counts)
# ax.set_ylim(-1,1)

# Mark, for each combination, the number of clusters with the highest score.
for best_nc in silhouette_table.idxmax():
    ax.axvline(x=best_nc, lw=1, ls='--', color='black')

ax.legend()

# Clustering in 3 groups using fuzzy c-means
n_clusters = 3
clus = FuzzyCMeans(n_clusters=n_clusters, metric=l2_distance, random_state=42)
clus.fit(fd_clin_mb_ss)

# One membership-probability column per cluster.
proba_columns = ['FDA_cluster_clin_mb_ss_proba_'+str(i) for i in range(n_clusters)]

# Store the assigned label, the per-cluster probabilities, and the largest
# probability of each row (used later as a confidence filter).
df_clin_mb_ss['FDA_cluster_clin_mb_ss'] = clus.labels_
df_clin_mb_ss[proba_columns] = clus.predict_proba(fd_clin_mb_ss)
df_clin_mb_ss['FDA_cluster_clin_mb_ss_proba_max'] = df_clin_mb_ss[proba_columns].max(axis=1)

## Plotting clusters
# One PDF per cluster showing every member curve whose highest membership
# probability is at least 0.4, coloured by the data set it comes from.

# Line appearance per data set; the legend reuses the same colours.
line_styles = {
    'metabolomics': {'alpha': .8, 'color': [0.134692, 0.658636, 0.517649, 1.      ], 'lw': 1},
    'proteomics': {'alpha': .5, 'color': colors[1], 'lw': 1.5},
    'clinical': {'alpha': .95, 'color': 'orange', 'lw': 1.5},
}

for c in np.sort(df_clin_mb_ss.FDA_cluster_clin_mb_ss.unique()):
    fig = plt.figure(figsize=(10,4))
    # NOTE(review): this rebinds the module-level `grid_points` (first value
    # -10 instead of -3) for plotting; any FDataGrid built after this loop
    # would pick up the plotting positions. Confirm that is intended.
    grid_points = [-10, 7, 13, 20, 26, 33, 51]

    ax = fig.add_subplot(1,1,1)

    in_cluster = df_clin_mb_ss.FDA_cluster_clin_mb_ss == c
    is_confident = df_clin_mb_ss.FDA_cluster_clin_mb_ss_proba_max >= 0.4
    # Descending sort on 'type' draws proteomics first and clinical last, so
    # the clinical curves end up on top. Columns up to and including 'type'
    # are kept: the time-point values followed by the data-set tag.
    shown = df_clin_mb_ss[in_cluster & is_confident].sort_values('type', ascending=False).loc[:, :'type']
    for y in shown.values:
        style = line_styles[y[-1]]
        ax.plot(grid_points, y[:-1], color=style['color'], alpha=style['alpha'], lw=style['lw'])

    # Corner label: cluster id and its total size (not only the curves drawn).
    cluster_size = df_clin_mb_ss[in_cluster].shape[0]
    ax.text(0.01, 0.01, str(c)+' ('+str(cluster_size)+')', ha='left', color='black', size=9, alpha=.9, transform=ax.transAxes)
    ax.set_xticks(grid_points, ['Pre_Tx', 'D7', 'D13', 'D20', 'D26', 'D33', 'D51'], rotation=90)
    ax.set_yticks([])
    ax.grid(visible=False)

    legend_elements = [
        Line2D([0], [0], color=line_styles[origin]['color'], lw=4, label=origin)
        for origin in ('clinical', 'metabolomics', 'proteomics')
    ]
    ax.legend(handles=legend_elements, loc='lower right', ncols=1)

    fig.savefig('FDA_proteomics_metabolomics_clicnical_'+str(c)+'_probaMin040.pdf', format='pdf', dpi=300, bbox_inches='tight')